# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

from torch import nn


def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module,
                                base_model_prefix: str, skip_key: str):
    uninitialized_encoder_weights: List[str] = []
    if decoder.__class__ != encoder.__class__:
        print(f'''{decoder.__class__} and {encoder.__class__} are not equal.
            In this case make sure that
            all encoder weights are correctly initialized.''')

    def tie_encoder_to_decoder_recursively(
        decoder_pointer: nn.Module,
        encoder_pointer: nn.Module,
        module_name: str,
        uninitialized_encoder_weights: List[str],
        skip_key: str,
        depth=0,
    ):
        assert isinstance(decoder_pointer, nn.Module) and isinstance(
            encoder_pointer, nn.Module
        ), f'{decoder_pointer} and {encoder_pointer}' + \
            'have to be of type torch.nn.Module'
        if hasattr(decoder_pointer, 'weight') and skip_key not in module_name:
            assert hasattr(encoder_pointer, 'weight')
            encoder_pointer.weight = decoder_pointer.weight
            if hasattr(decoder_pointer, 'bias'):
                assert hasattr(encoder_pointer, 'bias')
                encoder_pointer.bias = decoder_pointer.bias
            print(module_name + ' is tied')
            return

        encoder_modules = encoder_pointer._modules
        decoder_modules = decoder_pointer._modules
        if len(decoder_modules) > 0:
            assert (len(encoder_modules) >
                    0), f'''Encoder module {encoder_pointer}
            does not match decoder module {decoder_pointer}'''

            all_encoder_weights = set([
                module_name + '/' + sub_name
                for sub_name in encoder_modules.keys()
            ])
            encoder_layer_pos = 0
            for name, module in decoder_modules.items():
                if name.isdigit():
                    encoder_name = str(int(name) + encoder_layer_pos)
                    decoder_name = name
                    if not isinstance(
                            decoder_modules[decoder_name],
                            type(encoder_modules[encoder_name])) and len(
                                encoder_modules) != len(decoder_modules):
                        # this can happen if the name corresponds to
                        # the position in a list module list of layers
                        # in this case the decoder has added a
                        # cross-attention that the encoder doesn't have
                        # thus skip this step and
                        # subtract one layer pos from encoder
                        encoder_layer_pos -= 1
                        continue
                elif name not in encoder_modules:
                    continue
                elif depth > 500:
                    raise ValueError(
                        '''Max depth of recursive function `tie_encoder_to_decoder` reached.
                        It seems that there is a circular dependency
                        between two or more `nn.Modules` of your model.''')
                else:
                    decoder_name = encoder_name = name
                tie_encoder_to_decoder_recursively(
                    decoder_modules[decoder_name],
                    encoder_modules[encoder_name],
                    module_name + '/' + name,
                    uninitialized_encoder_weights,
                    skip_key,
                    depth=depth + 1,
                )
                all_encoder_weights.remove(module_name + '/' + encoder_name)

            uninitialized_encoder_weights += list(all_encoder_weights)

    # tie weights recursively
    tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix,
                                       uninitialized_encoder_weights, skip_key)
